# Root of the model-specification types; `setupData` and `reinitialize!` accept a `model::Model`.
abstract type Model end
# Program fixed-effects specification: `getXfixed` builds one indicator column per program ID.
struct ProgramFE <: Model end
# Program random-effects specification: `getXfixed` returns MCMC["X_fixed"] unchanged, and
# `reinitialize!` starts `sigsqProgramRE` at ones (zeros for any other model).
struct ProgramRE <: Model end


"""
    checkValidity(constants, lv)

Assert that the bound structures stored in `constants` and the latent
availabilities `lv.Astar` are mutually consistent.

Throws `AssertionError` at the first violated check; returns `nothing` otherwise.
"""
function checkValidity(constants, lv)
    # exactly four distinct values must appear in `typ`
    @assert length(unique(constants.typ)) == 4
    #@assert all(constants.A_lb .<= lv.A .<= constants.A_ub)

    # where the lower availability bound is 1, Astar must be strictly positive
    boundedOn = constants.A_lb .== 1
    @assert all(lv.Astar[boundedOn] .> 0)

    # where the upper availability bound is 0, Astar must be strictly negative
    boundedOff = constants.A_ub .== 0
    @assert all(lv.Astar[boundedOff] .< 0)

    # "feasible and unranked" may only be set where "feasible" is set
    @assert all(constants.feasible .>= constants.feasibleUnranked)

    # every entry not flagged feasible must carry -1 in both J bounds
    notFeasible = .!constants.feasible
    lbAtInfeasible = constants.J_lb[notFeasible]
    ubAtInfeasible = constants.J_ub[notFeasible]
    @assert all(lbAtInfeasible .== -1) && all(ubAtInfeasible .== -1)
    return nothing
end

"""
    getX(MCMC, ii, Xp, scores)

Build the matrix of match-level covariates for the student in row `ii`.

The result has one row per entry of `Xp` and these columns, in order:

1. row `ii` of `MCMC["Pij"]` multiplied elementwise by `Xp`;
2. row `ii` of `MCMC["X_dist_ij"]`;
3. `MCMC["X_interact_j"]` scaled by each entry of row `ii` of
   `MCMC["X_interact_i"]`, concatenated column-wise;
4. `scores[ii]`, `scores[ii]^2` and `scores[ii]^3`, each multiplied by `Xp`.

`Xp` is a per-program indicator (presumably "teaching program"; `setupData`
passes `TypeIndex[:,1] .== 7` — confirm the meaning against the data source).

A progress marker `" \$ii "` is printed whenever `(ii-1) % 10000 == 0`.
"""
function getX(MCMC,ii,Xp,scores)
    (ii-1)%10000 == 0 && print(" $ii ")
    # The annotations convert/assert concrete types: MCMC is an untyped container,
    # so without them every downstream operation would be dynamically dispatched.
    netprice::Array{Float64,1} = MCMC["Pij"][ii,:]
    netprice_teaching = netprice .* Xp
    sameregion = MCMC["X_dist_ij"][ii,:]
    Xinteract_j = MCMC["X_interact_j"]
    Xinteract_i = MCMC["X_interact_i"]
    # one block of columns per student-side interaction variable
    interact::Array{Float64,2} = reduce(hcat,
        [Xinteract_j .* Xinteract_i[ii,n] for n in 1:size(Xinteract_i,2)])
    score = scores[ii]
    testScoreSlopes = score .* Xp
    testScore2 = score^2 .* Xp
    testScore3 = score^3 .* Xp
    return hcat(netprice_teaching, sameregion, interact, testScoreSlopes, testScore2, testScore3)
end



"""
    getXfixed(MCMC, ::ProgramRE)

Return the stored `MCMC["X_fixed"]` entry unchanged for the random-effects specification.
"""
function getXfixed(MCMC, ::ProgramRE)
    return MCMC["X_fixed"]
end
"""
    getXfixed(MCMC, model::ProgramFE, ind=4)

Build a dense indicator (one-hot) matrix from column `ind` of `MCMC["TypeIndex"]`.

Row `j` has a single `1.0` in the column given by the ID found in
`TypeIndex[j, ind]`; the matrix has as many columns as the largest ID.
Per the original note, the Matlab layout is `typeIndex = [area major inst program]`,
so the default `ind = 4` selects the program ID.
"""
function getXfixed(MCMC,model::ProgramFE,ind=4)
    ids = Array{Int,1}(MCMC["TypeIndex"][:,ind])
    Xprogram = zeros(length(ids), maximum(ids))
    # set entry (j, ids[j]) for every row j in one broadcast assignment
    Xprogram[CartesianIndex.(1:length(ids), ids)] .= 1
    return Xprogram
end

"""
    restrictToRelevantSet(MCMC, mysample, bw_upper, bw_lower, _AppsAll)

Trim each sampled student's ranked application list to the entries that matter.

Reads `EligibleToApply`, `Priorities`, `platform` and `cutoff` from `MCMC`.
A program is *clearly infeasible* for a student when it is on-platform and the
score falls more than `abs(bw_lower)` below the cutoff, and *clearly feasible*
when it is on-platform and the score is at least `bw_upper` above the cutoff.

Walking the list in rank order, a listed program is kept if the student is
eligible and it is not clearly infeasible; the walk stops right after the first
clearly feasible program (which is itself kept when eligible).

Returns `(AppsRelevant, AppsAll, clearly_infeasible, clearly_feasible)`, where
`AppsAll` is the integer application matrix with `NaN` replaced by `0` and
`AppsRelevant` has the same size, left-packed and zero-padded.
"""
function restrictToRelevantSet(MCMC,mysample=1:Int(MCMC["I"]),bw_upper=2500,bw_lower=2500,_AppsAll = MCMC["Apps"])
    AppsAll = Array{Int,2}(replace(_AppsAll[mysample,:], NaN=>0))
    eligible = MCMC["EligibleToApply"][mysample,:]
    onPlatform = BitArray{1}(vec(MCMC["platform"]))
    # score minus cutoff: nonnegative if the student would be admitted
    margin = MCMC["Priorities"][mysample,:] .- MCMC["cutoff"]'
    clearly_infeasible = onPlatform' .& (margin .< -abs.(bw_lower))
    clearly_feasible = onPlatform' .& (margin .>= bw_upper)

    AppsRelevant = zeros(Int, size(AppsAll))
    for ii in 1:length(mysample)
        nkept = 0
        for jj in view(AppsAll, ii, :)
            jj > 0 || continue  # empty slot in the ranked list
            if eligible[ii,jj] && !clearly_infeasible[ii,jj]
                nkept += 1
                AppsRelevant[ii,nkept] = jj
            end
            # nothing ranked below the first safety school is relevant
            clearly_feasible[ii,jj] && break
        end
    end
    return AppsRelevant, AppsAll, clearly_infeasible, clearly_feasible
end

"""
    setupXfrictionFull(G8, platform, sameregion)

Return one covariate matrix per row of `sameregion`.

Each matrix has one row per program and four columns:
`!G8 & platform`, `G8 & platform`, `G8 & !platform`, and that row of `sameregion`.
The first three columns are identical across the returned matrices.
"""
function setupXfrictionFull(G8,platform,sameregion)
    notG8 = .!(G8)
    offPlatform = .!(platform)
    # program-level indicator columns shared by every student
    shared = hcat(notG8 .* platform, G8 .* platform, G8 .* offPlatform)
    return [hcat(shared, sameregion[ii,:]) for ii in 1:size(sameregion,1)]
end


"""
    setupFrictionCovariates(inds_WL_typ, Xfriction_full)

Stack selected rows of the per-student matrices in `Xfriction_full`.

Each element of `inds_WL_typ` is indexed as `(jj, ii)`; output row `n` is row `jj`
of `Xfriction_full[ii]`. The column count is taken from `Xfriction_full[1]`.
"""
function setupFrictionCovariates(inds_WL_typ, Xfriction_full)
    ncols = size(Xfriction_full[1], 2)
    stacked = zeros(length(inds_WL_typ), ncols)
    for (row, pair) in enumerate(inds_WL_typ)
        jj, ii = pair[1], pair[2]
        stacked[row, :] .= view(Xfriction_full[ii], jj, :)
    end
    return stacked
end




"""
    setupData(MCMC, year, mysample, model, encodeXs)

Assemble the named tuple of constants for one year of data.

# Arguments
- `MCMC`: key-indexed container of raw arrays (read with string keys).
- `year`: stored verbatim in the result.
- `mysample`: student rows to keep; defaults to `1:Int(MCMC["I"])`.
- `model`: model specification; defaults to `ProgramFE()`.
- `encodeXs`: when `false`, the friction covariates and the match-level covariates
  are skipped and empty placeholders are stored instead.

# Returns
A `NamedTuple` holding the bound structures returned by `getBoundStructures`, the
sample's types and enrollment, the covariate arrays, and bookkeeping indices.
"""
function setupData(MCMC,year,mysample=1:Int(MCMC["I"]),model::Model=ProgramFE(),encodeXs=true)
    I = length(mysample)
    institutionID = vec(Int.(MCMC["TypeIndex"][:,3]))
    G8 = institutionID .> 25  # NOTE(review): institution-ID threshold is hard-coded; confirm its meaning
    platform = BitArray{1}(vec(MCMC["platform"]))
    sameregion = MCMC["X_dist_ij"][mysample,:]
    ntypes = length(unique(MCMC["Type0"]))

    println("\n setupData: bounds")
    AppsRelevant,AppsAll,clearlyInfeasible,clearlyFeasible = restrictToRelevantSet(MCMC,mysample)
    (J_lb,J_ub,OO_lb,OO_ub,waitlistMightCall,A_lb,A_ub,feasibleUnranked) = getBoundStructures(MCMC,AppsRelevant,AppsAll,clearlyInfeasible,clearlyFeasible,mysample)
    types = vec(Int.(MCMC["Type0"]))[mysample]
    # off-platform programs get the -1 marker as their upper bound
    J_ub[.!platform,:] .= -1

    println("setupData: WL and off-platform admissions shifters")
    # per type: (jj,ii) positions where waitlistMightCall is set
    inds_WL = [findall((types.==typ)' .& waitlistMightCall) for typ=1:ntypes]
    if encodeXs
        Xfriction_full = setupXfrictionFull(G8,platform,sameregion)
        Xfriction = [setupFrictionCovariates(inds_wl_typ,Xfriction_full) for inds_wl_typ in inds_WL]
    else
        # skipped when only doing model fit, not estimating or simulating
        Xfriction_full = Float64[]
        Xfriction = Float64[]
    end

    # program ID for program-level random effects, consistent across years
    programIDs = Array{Int,1}(MCMC["TypeIndex"][:,4])
    println("setupData: getting program indicators")
    # NOTE(review): the `::ProgramFE` assertion makes this line throw a TypeError for any
    # other model (e.g. ProgramRE) even though the signature accepts every `Model`.
    # Left unchanged because the last-column drop may only be meant for the dummy encoding.
    Xfixed = getXfixed(MCMC,model::ProgramFE)[:,1:end-1]
    scores = (MCMC["X_oo"][:,2] .+ MCMC["X_oo"][:,3])./2 #mean math/reading

    # placeholders; overwritten below when the match-level covariates are encoded
    Xij = [zeros(0,0) for ii in mysample]
    if encodeXs
        print("setupData: getting match-level covariates: ")
        Xp = (MCMC["TypeIndex"][:,1] .== 7)
        # each iteration writes only its own slot of Xij
        Threads.@threads for mm=1:length(mysample)
            Xij[mm] = getX(MCMC,mysample[mm],Xp,scores)
        end
    else
        println("Skipping construction of match-level indicators")
    end
    Xoo = [MCMC["X_oo"][mysample,:] scores[mysample].^2 scores[mysample].^3]

    constants = (
        year=year,
        sample=mysample,
        specialAdmit=falses(I),
        Grad6 = Array{Bool,1}(MCMC["grad6"][mysample]),
        A_ub=A_ub,
        A_lb=A_lb,
        J_ub=J_ub,
        J_lb=J_lb,
        OO_ub=OO_ub,
        OO_lb=OO_lb,
        typ=types,
        ntypes=ntypes,
        waitlistMightCall=waitlistMightCall,
        feasibleUnranked=feasibleUnranked, #used in estimation to know which programs to loop over when bound is "best unlisted thing"
        feasible = collect(.!clearlyInfeasible'), #used in cleanup and starting vals
        platform = platform,
        AppsRelevant = AppsRelevant, #needed for starting values only
        AppsAll = AppsAll,
        enroll = Array{Int,1}(vec(MCMC["enrollmentIndex"]))[mysample],
        Xij = Xij, Xrc = MCMC["X_rc"], Xfixed=Xfixed, Xoo=Xoo,
        Xfriction=Xfriction, Xfriction_full=Xfriction_full, #"full" is (J,I)
        inds_WL=inds_WL, programIDs=programIDs,G8=G8,
        )
    return constants
end

"""
    getLatentVars(cc)

Allocate zero-filled storage for the latent variables and scratch arrays.

Returns `(latentvars, temps)`. Array sizes are derived from `cc.J_lb` (programs by
students), `cc.Xrc`, `cc.Xij[1]`, `cc.Xfixed`, `cc.Xoo` and `cc.ntypes`.

`temps.Xoutcome[typ]` has one row per student of that type with `enroll > 0` and
`specialAdmit == false`, and `1 + noo + nmatch + nfix` columns; `temps.indEnroll[ii]`
is that student's row number within its type's matrix (0 if the student has none).
"""
function getLatentVars(cc)
    nprog, nstud = size(cc.J_lb)
    nRC = size(cc.Xrc,2)
    ntypes = cc.ntypes
    # outcome row = 1 utility column + Xoo columns + Xij columns + Xfixed columns
    noutcome = 1 + size(cc.Xoo,2) + size(cc.Xij[1],2) + size(cc.Xfixed,2)

    hasOutcomeRow = (cc.enroll .> 0) .& .!(cc.specialAdmit)
    indEnroll = zeros(Int,nstud)
    Xoutcome = Array{Float64,2}[]
    for typ in 1:ntypes
        members = findall((cc.typ .== typ) .& hasOutcomeRow)
        indEnroll[members] .= 1:length(members)
        push!(Xoutcome, zeros(length(members), noutcome))
    end

    latentvars = (
        U = zeros(nprog,nstud),
        Astar = zeros(nprog,nstud),
        U0 = zeros(2,nstud),
        rc = zeros(nRC,nstud),
        H = zeros(nstud),
    )
    temps = (
        mu = zeros(nprog,nstud),
        mu0 = zeros(2,nstud),
        ufe = zeros(nprog,ntypes),
        ure = zeros(nprog,ntypes),
        umatch = zeros(nprog,nstud),
        urc = zeros(nprog,nstud),
        Xoutcome = Xoutcome,
        indEnroll = indEnroll,
        Astar_typ = [zeros(length(inds_WL_typ)) for inds_WL_typ in cc.inds_WL],
    )
    return latentvars, temps
end

""" startvals!(constants,latentvars): set latent utility and availablity to feasible starting values """
function startvals!(cc,lv,temps)
    # Overwrites lv.U, lv.U0, lv.rc, lv.Astar and lv.H in place with starting values,
    # then refreshes temps.Xoutcome through fillXoutcome!.
    println("getting starting values")
    nstud = size(lv.U, 2)
    lv.U .= .0001 .* rand.()                       # small random baseline utilities
    lv.U0[1,:] .= ifelse.(cc.OO_lb .== -2, 0.001, -0.001)
    lv.U0[2,:] .= 0.002
    lv.rc .= rand.(Normal())
    # start w/ everything feasible that could possibly be feasible: A_ub=1 -> +1, A_ub=0 -> -1
    lv.Astar .= 2.0 .* cc.A_ub .- 1.0
    Threads.@threads for ii in 1:nstud
        (ii-1)%10000 == 0 && print(" $ii ")
        # relevant listed programs get utility 1/rank; the list is zero-terminated
        for (rr, jj) in enumerate(view(cc.AppsRelevant, ii, :))
            jj == 0 && break
            lv.U[jj,ii] = 1/rr
        end
        # things dominated by outside option are worth -1
        dominated = view(cc.J_ub, :, ii) .== 0
        lv.U[dominated, ii] .= -1.0
        # if enrolled option is oo or offplatform, it gets utility 2 to ensure it dominates onplatform stuff
        enroll = cc.enroll[ii]
        if enroll == 0 #enrolled in oo
            lv.U0[2,ii] = 2
        elseif cc.J_lb[enroll,ii] == cc.J_ub[enroll,ii] == -1 #!cc.platform[enroll] or enrolled via non-platform channel
            lv.U[enroll,ii] = 2
        end
        lv.H[ii] = cc.Grad6[ii] ? 1.0 : -0.5
    end
    return fillXoutcome!(cc,lv,temps)
end

#recursively bound utilities of feasible schools: things bounded above by a program with u=1/n have u=1/(n+1)
# function _initialU!(cc,lv,counter,j,ii)
# 	lv.U[j,ii] = 1/counter
# 	@views js = findall(cc.J_ub[:,ii].==j) #find schools whose ub is previous school
# 	for j2 in js
# 		_initialU!(cc,lv,counter+1,j2,ii)
# 	end
# 	return
# end

"""
    cleanEnrollment!(cc)

Flag students whose on-platform enrollment contradicts the feasibility bounds.

A student is flagged when the enrolled program `e` is on-platform and is either
not marked feasible or is marked feasible-and-unranked. For a flagged student a
message is printed, `A_lb[e,ii] == A_ub[e,ii] == 1` is asserted, both `J_lb[e,ii]`
and `J_ub[e,ii]` are set to `-1`, `feasibleUnranked[e,ii]` is cleared, and
`specialAdmit[ii]` is set. Every other student has `specialAdmit[ii]` reset to
`false`. The arrays inside `cc` are mutated in place.
"""
function cleanEnrollment!(cc)
    nstud = size(cc.J_lb,2)
    for ii in 1:nstud
        e = cc.enroll[ii]
        # cc.feasible iff e is not ex-ante clearly infeasible
        inconsistent = e > 0 && cc.platform[e] &&
            (!cc.feasible[e,ii] || cc.feasibleUnranked[e,ii])
        if !inconsistent
            cc.specialAdmit[ii] = false
            continue
        end
        mymsg = cc.feasible[e,ii] ? "it was feasible and unranked" : "it wasn't feasible"
        println("person $ii enrolled in on-platform program $e but $mymsg")
        @assert cc.A_lb[e,ii] == 1
        @assert cc.A_ub[e,ii] == 1
        cc.J_lb[e,ii] = -1
        cc.J_ub[e,ii] = -1
        cc.feasibleUnranked[e,ii] = false
        cc.specialAdmit[ii] = true
    end
    return nothing
end


"""
    makeAdditionalConstants(cc, lv)

Return `cc` extended with the field `Xij_t_Xij`, a vector holding, for each type
`1:cc.ntypes`, the sum of `Xij[ii]' * Xij[ii]` over the students `ii` of that type.

If `cc` already has the field `Xij_t_Xij` it is returned unchanged. A type with no
students gets an all-zero matrix (summing over an empty set previously threw).
`lv` is accepted for signature compatibility and is not used.
"""
function makeAdditionalConstants(cc,lv)
    haskey(cc,:Xij_t_Xij) && return cc
    # width of the match-covariate matrices; 0 when there are no students at all
    nmatch = isempty(cc.Xij) ? 0 : size(first(cc.Xij),2)
    Xij_t_Xij = Array{Float64,2}[]
    for typ = 1:cc.ntypes
        println("precomputing term for :updateMatchTerms, type $typ")
        gram = zeros(nmatch,nmatch)
        for ii in findall(cc.typ .== typ)
            Xi = cc.Xij[ii]
            gram .+= Xi' * Xi
        end
        push!(Xij_t_Xij, gram)
    end
    # merge appends the new field after the existing ones and keeps concrete field types
    return merge(cc, (Xij_t_Xij = Xij_t_Xij,))
end

"""
    reinitialize!(dataset, model::Model; refyear=2011)

Convenience function: draw a fresh parameter tuple `theta` and reset the latent
variables of every year in `dataset` to starting values via `startvals!`.

`dataset` maps a year to a `(constants, latentvars, temps)` triple. Parameter
dimensions are read from the constants of `dataset[refyear]`; `refyear` defaults
to 2011, the year that was previously hard-coded. The program IDs pooled over all
years must be exactly `1:maximum(id)`.

Returns `theta`; the latent variables stored in `dataset` are mutated in place.
"""
function reinitialize!(dataset,model::Model; refyear=2011)
    constants = dataset[refyear][1]
    nt = constants.ntypes
    noo = size(constants.Xoo,2)
    (J,nbetaij) = size(constants.Xij[1])
    @assert size(constants.Xfixed,1) == J
    ndelta = size(constants.Xfixed,2)
    uniquePrograms = union([dataset[year][1].programIDs for year in keys(dataset)]...)
    nPrograms = length(uniquePrograms)
    @assert sort(uniquePrograms) == 1:maximum(uniquePrograms)
    # program-effect variance starts at 1 under the RE model and at 0 otherwise
    sigsqProgramRE = model isa ProgramRE ? ones(nt) : zeros(nt)
    nalpha = size(constants.Xfriction_full[1],2)
    theta = (
        alpha=.5 .* ones(nalpha,nt), #on-platform aftermarket missed-call rate
        beta_ij = .01 .* randn(nbetaij,nt), #coeff on observed match terms
        programRE = randn(nPrograms,nt), #program RE
        beta_fixed = .01 .* randn(ndelta,nt), #coeff on course content, inst. dummies, major dummies.
        sigsqProgramRE = sigsqProgramRE, # variance of program effects
        # NOTE(review): 2x2 is hard-coded; presumably should match size(constants.Xrc,2) — confirm
        Sigma_rc = cat([diagm(ones(2)) for t=1:nt]...,dims=3),
        betaOO0 = zeros(noo,nt),  #ex-ante outside option mean
        betaOO1 = zeros(noo,nt),  #ex-post outside option mean
        sigsqOO0 = ones(nt), #ex-ante outside option variance
        sigsqOO1 = ones(nt), #ex-post outside option variance
        betaOutcome = .01 .* randn(1+noo+nbetaij+ndelta,nt),
    )
    for year in keys(dataset)
        startvals!(dataset[year]...)
    end
    return theta
end

"""
    fillXoutcome!(cc, lv, temps)

Fill the outcome-regression rows in `temps.Xoutcome` from the current latent utilities.

For every student with `enroll > 0` and `specialAdmit == false`, row
`temps.indEnroll[ii]` of that student's type matrix is set to: the utility
`lv.U[e,ii]` of the enrolled program `e`, then row `ii` of `cc.Xoo`, then row `e`
of `cc.Xij[ii]`, then row `e` of `cc.Xfixed`. Other students are skipped.
"""
function fillXoutcome!(cc,lv,temps)
    nstud = size(lv.U,2)
    noo = size(cc.Xoo,2)
    nmatch = size(cc.Xij[1],2)
    ooCols = 2:noo+1
    matchCols = noo+2:noo+nmatch+1
    for ii in 1:nstud
        e = cc.enroll[ii]
        (e > 0 && !cc.specialAdmit[ii]) || continue
        # writes through the view land in the type's Xoutcome matrix
        row = view(temps.Xoutcome[cc.typ[ii]], temps.indEnroll[ii], :)
        #row[1] = log(lv.U0[1,ii])
        row[1] = lv.U[e,ii]
        row[ooCols] .= view(cc.Xoo,ii,:)
        row[matchCols] .= view(cc.Xij[ii],e,:)
        row[noo+nmatch+2:end] .= view(cc.Xfixed,e,:)
    end
    return nothing
end
